{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the notebook's dependencies (run this cell first, e.g. on Colab).\n",
    "# The extras specifier is quoted so the shell cannot interpret the brackets as a glob pattern.\n",
    "%pip install \"sf-hamilton[visualization]\" pandas scikit-learn numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
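    "# MPG Simple Advanced Target [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleAdvancedTarget.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleAdvancedTarget.ipynb)\n",
    "\n",
    "This notebook builds a small Hamilton dataflow directly in a notebook cell (via the `%%cell_to_module` magic): it loads the UCI Auto MPG dataset, one-hot encodes the `Origin` column, splits the rows into train/test sets, fits a standardized linear regression that predicts the target column (`MPG`, supplied by its own `target_column` node), and reports the mean absolute error on the test set.\n"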
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2024-07-20T17:55:33.133151Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f7ca0a2e-99c4-49de-af45-c8c4bddf5685",
     "showTitle": false,
     "title": ""
    },
    "jupyter": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from hamilton import driver\n",
    "from IPython.display import HTML, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:32:57.962238Z",
     "start_time": "2024-07-20T17:32:57.947142Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "45fcd1cf-5dee-4d3c-b598-823c82654805",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%load_ext hamilton.plugins.jupyter_magic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:00.825770Z",
     "start_time": "2024-07-20T17:37:00.183488Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "155ea802-aef6-4d5c-b264-d9ec5b57c733",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"865pt\" height=\"304pt\"\n",
       " viewBox=\"0.00 0.00 864.60 303.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 299.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-299.8 860.6,-299.8 860.6,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"31.38,-157.8 31.38,-287.8 116.22,-287.8 116.22,-157.8 31.38,-157.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-270.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- target_column -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>target_column</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M392.8,-256.6C392.8,-256.6 301.45,-256.6 301.45,-256.6 295.45,-256.6 289.45,-250.6 289.45,-244.6 289.45,-244.6 289.45,-205 289.45,-205 289.45,-199 295.45,-193 301.45,-193 301.45,-193 392.8,-193 392.8,-193 398.8,-193 404.8,-199 404.8,-205 404.8,-205 404.8,-244.6 404.8,-244.6 404.8,-250.6 398.8,-256.6 392.8,-256.6\"/>\n",
       "<text text-anchor=\"start\" x=\"300.25\" y=\"-233.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n",
       "<text text-anchor=\"start\" x=\"339.62\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M844.6,-160.6C844.6,-160.6 736.75,-160.6 736.75,-160.6 730.75,-160.6 724.75,-154.6 724.75,-148.6 724.75,-148.6 724.75,-109 724.75,-109 724.75,-103 730.75,-97 736.75,-97 736.75,-97 844.6,-97 844.6,-97 850.6,-97 856.6,-103 856.6,-109 856.6,-109 856.6,-148.6 856.6,-148.6 856.6,-154.6 850.6,-160.6 844.6,-160.6\"/>\n",
       "<text text-anchor=\"start\" x=\"735.55\" y=\"-137.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"780.18\" y=\"-109.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.29,-233.76C476.21,-242.38 599.51,-249.12 695.75,-210.8 718.39,-201.78 739.28,-184.84 755.58,-168.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"757.68,-171.55 762.15,-161.94 752.67,-166.66 757.68,-171.55\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M527.4,-183.6C527.4,-183.6 445.8,-183.6 445.8,-183.6 439.8,-183.6 433.8,-177.6 433.8,-171.6 433.8,-171.6 433.8,-132 433.8,-132 433.8,-126 439.8,-120 445.8,-120 445.8,-120 527.4,-120 527.4,-120 533.4,-120 539.4,-126 539.4,-132 539.4,-132 539.4,-171.6 539.4,-171.6 539.4,-177.6 533.4,-183.6 527.4,-183.6\"/>\n",
       "<text text-anchor=\"start\" x=\"444.6\" y=\"-160.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"476.1\" y=\"-132.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;linear_model -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.23,-194.48C411.28,-191.26 417.45,-187.99 423.56,-184.75\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"425.2,-187.84 432.39,-180.06 421.91,-181.65 425.2,-187.84\"/>\n",
       "</g>\n",
       "<!-- lr_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>lr_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M683.75,-201.6C683.75,-201.6 580.4,-201.6 580.4,-201.6 574.4,-201.6 568.4,-195.6 568.4,-189.6 568.4,-189.6 568.4,-150 568.4,-150 568.4,-144 574.4,-138 580.4,-138 580.4,-138 683.75,-138 683.75,-138 689.75,-138 695.75,-144 695.75,-150 695.75,-150 695.75,-189.6 695.75,-189.6 695.75,-195.6 689.75,-201.6 683.75,-201.6\"/>\n",
       "<text text-anchor=\"start\" x=\"603.58\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n",
       "<text text-anchor=\"start\" x=\"579.2\" y=\"-150.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n",
       "</g>\n",
       "<!-- lr_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>lr_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M695.87,-153.36C701.57,-151.87 707.39,-150.34 713.19,-148.82\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"714.06,-152.22 722.84,-146.3 712.28,-145.44 714.06,-152.22\"/>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M248.45,-105.6C248.45,-105.6 188.6,-105.6 188.6,-105.6 182.6,-105.6 176.6,-99.6 176.6,-93.6 176.6,-93.6 176.6,-54 176.6,-54 176.6,-48 182.6,-42 188.6,-42 188.6,-42 248.45,-42 248.45,-42 254.45,-42 260.45,-48 260.45,-54 260.45,-54 260.45,-93.6 260.45,-93.6 260.45,-99.6 254.45,-105.6 248.45,-105.6\"/>\n",
       "<text text-anchor=\"start\" x=\"187.4\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"208.02\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.18,-102.74C133.93,-98.4 150.48,-93.53 165.74,-89.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"166.52,-92.46 175.13,-86.28 164.54,-85.74 166.52,-92.46\"/>\n",
       "</g>\n",
       "<!-- test_dataset -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>test_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M385.3,-64.6C385.3,-64.6 308.95,-64.6 308.95,-64.6 302.95,-64.6 296.95,-58.6 296.95,-52.6 296.95,-52.6 296.95,-13 296.95,-13 296.95,-7 302.95,-1 308.95,-1 308.95,-1 385.3,-1 385.3,-1 391.3,-1 397.3,-7 397.3,-13 397.3,-13 397.3,-52.6 397.3,-52.6 397.3,-58.6 391.3,-64.6 385.3,-64.6\"/>\n",
       "<text text-anchor=\"start\" x=\"307.75\" y=\"-41.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"313\" y=\"-13.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;test_dataset -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;test_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-60.49C268.75,-57.86 277.43,-55.05 286.03,-52.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"286.96,-55.64 295.4,-49.23 284.81,-48.98 286.96,-55.64\"/>\n",
       "</g>\n",
       "<!-- train_dataset -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>train_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M387.93,-165.6C387.93,-165.6 306.32,-165.6 306.32,-165.6 300.32,-165.6 294.32,-159.6 294.32,-153.6 294.32,-153.6 294.32,-114 294.32,-114 294.32,-108 300.32,-102 306.32,-102 306.32,-102 387.93,-102 387.93,-102 393.93,-102 399.93,-108 399.93,-114 399.93,-114 399.93,-153.6 399.93,-153.6 399.93,-159.6 393.93,-165.6 387.93,-165.6\"/>\n",
       "<text text-anchor=\"start\" x=\"305.12\" y=\"-142.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">train_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"313\" y=\"-114.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;train_dataset -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;train_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-93.28C268.06,-96.8 275.96,-100.54 283.83,-104.28\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"282.23,-107.39 292.76,-108.51 285.23,-101.07 282.23,-107.39\"/>\n",
       "</g>\n",
       "<!-- scaler -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>scaler</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M677.38,-119.6C677.38,-119.6 586.78,-119.6 586.78,-119.6 580.78,-119.6 574.78,-113.6 574.78,-107.6 574.78,-107.6 574.78,-68 574.78,-68 574.78,-62 580.78,-56 586.78,-56 586.78,-56 677.38,-56 677.38,-56 683.38,-56 689.38,-62 689.38,-68 689.38,-68 689.38,-107.6 689.38,-107.6 689.38,-113.6 683.38,-119.6 677.38,-119.6\"/>\n",
       "<text text-anchor=\"start\" x=\"612.58\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n",
       "<text text-anchor=\"start\" x=\"585.58\" y=\"-68.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n",
       "</g>\n",
       "<!-- scaler&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>scaler&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M689.61,-102.6C697.38,-104.64 705.46,-106.75 713.51,-108.86\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"712.41,-112.19 722.97,-111.34 714.19,-105.42 712.41,-112.19\"/>\n",
       "</g>\n",
       "<!-- test_dataset&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>test_dataset&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M397.75,-24.79C467.18,-15.77 596.01,-7.08 695.75,-46.8 718.39,-55.82 739.28,-72.76 755.58,-88.92\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"752.67,-90.94 762.15,-95.66 757.68,-86.05 752.67,-90.94\"/>\n",
       "</g>\n",
       "<!-- train_dataset&#45;&gt;linear_model -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>train_dataset&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M400.1,-140.61C407.37,-141.56 414.91,-142.54 422.36,-143.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"421.68,-146.96 432.05,-144.79 422.59,-140.02 421.68,-146.96\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;lr_model -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;lr_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-158.35C545.37,-159.05 551.11,-159.77 556.86,-160.49\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"556.06,-163.92 566.42,-161.69 556.93,-156.97 556.06,-163.92\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;scaler -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;scaler</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-128.51C547.69,-125 555.92,-121.32 564.07,-117.69\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"565.22,-121.01 572.93,-113.74 562.37,-114.62 565.22,-121.01\"/>\n",
       "</g>\n",
       "<!-- _data_sets_inputs -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>_data_sets_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n",
       "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n",
       "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n",
       "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n",
       "</g>\n",
       "<!-- _data_sets_inputs&#45;&gt;data_sets -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>_data_sets_inputs&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M147.83,-53.78C153.77,-55.49 159.69,-57.19 165.43,-58.83\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"164.17,-62.11 174.75,-61.51 166.1,-55.39 164.17,-62.11\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-202.1 46.8,-202.1 46.8,-165.5 100.8,-165.5 100.8,-202.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-178\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-257.1C96.22,-257.1 51.37,-257.1 51.37,-257.1 45.37,-257.1 39.37,-251.1 39.37,-245.1 39.37,-245.1 39.37,-232.5 39.37,-232.5 39.37,-226.5 45.37,-220.5 51.37,-220.5 51.37,-220.5 96.22,-220.5 96.22,-220.5 102.22,-220.5 108.22,-226.5 108.22,-232.5 108.22,-232.5 108.22,-245.1 108.22,-245.1 108.22,-251.1 102.22,-257.1 96.22,-257.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-233\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x1516f2d60>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%cell_to_module pipeline --display\n",
    "# when done you can write to file and then load it as a module normally\n",
    "# add -w to do so\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "from hamilton.function_modifiers import extract_fields\n",
    "\n",
    "\n",
    "def mpg_df() -> pd.DataFrame:\n",
    "    \"\"\"Load the UCI Auto MPG dataset from the UCI archive.\n",
    "\n",
    "    The file is read without a header row (column names are supplied explicitly)\n",
    "    as space-separated values; ``?`` marks missing values and anything after a tab\n",
    "    character on a line is ignored (``comment`` is set to the tab character).\n",
    "    ``Model Year`` is renamed to ``ModelYear`` so the column name has no space.\n",
    "    \"\"\"\n",
    "    source_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'\n",
    "    columns = [\n",
    "        'MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',\n",
    "        'Acceleration', 'Model Year', 'Origin',\n",
    "    ]\n",
    "    raw = pd.read_csv(\n",
    "        source_url,\n",
    "        names=columns,\n",
    "        na_values='?',\n",
    "        comment='\\t',\n",
    "        sep=' ',\n",
    "        skipinitialspace=True,\n",
    "    )\n",
    "    # schema tweak: make the year column a single token\n",
    "    return raw.rename(columns={\"Model Year\": \"ModelYear\"})\n",
    "\n",
    "\n",
    "@extract_fields({\"train_dataset\": pd.DataFrame, \"test_dataset\": pd.DataFrame})\n",
    "def data_sets(mpg_df: pd.DataFrame, train_test_split: float = 0.8, seed: int = 123) -> dict:\n",
    "    \"\"\"Feature-engineer ``mpg_df`` and split it into train/test sets.\n",
    "\n",
    "    ``Origin`` codes 1/2/3 are one-hot encoded into 0/1 ``USA``/``Europe``/``Japan``\n",
    "    columns, rows with any missing value are dropped, and a seeded random\n",
    "    ``train_test_split`` fraction of the rows becomes the training set; the\n",
    "    remaining rows become the test set.\n",
    "\n",
    "    ``mpg_df`` is not modified: it is the upstream node's output, so writing the\n",
    "    encoded columns into it would leak into anything else that reads ``mpg_df``.\n",
    "\n",
    "    :param mpg_df: raw dataset; must contain an ``Origin`` column.\n",
    "    :param train_test_split: fraction of rows to sample into the training set.\n",
    "    :param seed: ``random_state`` for the sampling, so the split is reproducible.\n",
    "    :return: dict with ``train_dataset`` and ``test_dataset`` DataFrames.\n",
    "    \"\"\"\n",
    "    origin_to_country = {1: \"USA\", 2: \"Europe\", 3: \"Japan\"}\n",
    "    # assign() returns a new frame (one-hot columns appended in this order)\n",
    "    # instead of mutating the caller's DataFrame in place.\n",
    "    encoded = mpg_df.assign(**{\n",
    "        country: np.where(mpg_df[\"Origin\"] == value, 1, 0)\n",
    "        for value, country in origin_to_country.items()\n",
    "    })\n",
    "    raw_dataset = encoded.dropna()\n",
    "    # split the pandas dataframe into train and test\n",
    "    train_dataset = raw_dataset.sample(frac=train_test_split, random_state=seed)\n",
    "    test_dataset = raw_dataset.drop(train_dataset.index)\n",
    "\n",
    "    return {\"train_dataset\": train_dataset, \"test_dataset\": test_dataset}\n",
    "\n",
    "\n",
    "def target_column() -> str:\n",
    "    \"\"\"Name of the label column the model is trained to predict.\"\"\"\n",
    "    label = \"MPG\"\n",
    "    return label\n",
    "\n",
    "\n",
    "@extract_fields({\"lr_model\": LinearRegression, \"scaler\": StandardScaler})\n",
    "def linear_model(train_dataset: pd.DataFrame, target_column: str) -> dict:\n",
    "    \"\"\"Fit a feature scaler and a linear regression on the training set.\n",
    "\n",
    "    Features are every column of ``train_dataset`` except ``target_column``;\n",
    "    boolean columns are cast to int, the features are standardized with a newly\n",
    "    fitted ``StandardScaler``, and a ``LinearRegression`` is fit on the result.\n",
    "\n",
    "    ``train_dataset`` is read, not mutated: popping the label out of it would\n",
    "    strip ``target_column`` from the upstream node's output for every other reader.\n",
    "\n",
    "    :param train_dataset: training rows, including the label column.\n",
    "    :param target_column: name of the label column.\n",
    "    :return: dict with the fitted ``lr_model`` and the fitted ``scaler``.\n",
    "    \"\"\"\n",
    "    # separate label and features without touching the input frame\n",
    "    train_labels = train_dataset[target_column]\n",
    "    features = train_dataset.drop(columns=[target_column])\n",
    "    # convert boolean columns to integers for the model (features is a new frame)\n",
    "    bool_columns = features.select_dtypes(include=[bool]).columns\n",
    "    features[bool_columns] = features[bool_columns].astype(int)\n",
    "    # normalize the features; the fitted scaler is returned for reuse on test data\n",
    "    scaler = StandardScaler()\n",
    "    features_scaled = scaler.fit_transform(features)\n",
    "\n",
    "    model = LinearRegression()\n",
    "    model.fit(features_scaled, train_labels)\n",
    "    return {\"lr_model\": model, \"scaler\": scaler}\n",
    "\n",
    "\n",
    "def evaluated_model(lr_model: LinearRegression,\n",
    "                    scaler: StandardScaler,\n",
    "                    test_dataset: pd.DataFrame, target_column: str) -> dict:\n",
    "    \"\"\"Evaluate the fitted model on the held-out test set.\n",
    "\n",
    "    The label column is separated out, boolean columns are cast to int, the\n",
    "    features are transformed with the already-fitted ``scaler`` (no refit), and\n",
    "    the mean absolute error of ``lr_model``'s predictions is reported.\n",
    "\n",
    "    ``test_dataset`` is read, not mutated, so the upstream node output keeps\n",
    "    ``target_column`` for any other reader.\n",
    "\n",
    "    :param lr_model: fitted regression model.\n",
    "    :param scaler: scaler fitted on the training features.\n",
    "    :param test_dataset: test rows, including the label column.\n",
    "    :param target_column: name of the label column.\n",
    "    :return: dict mapping ``linear_model`` to the test-set MAE.\n",
    "    \"\"\"\n",
    "    test_labels = test_dataset[target_column]\n",
    "    features = test_dataset.drop(columns=[target_column])\n",
    "    # convert boolean columns to integers for the model (features is a new frame)\n",
    "    bool_columns = features.select_dtypes(include=[bool]).columns\n",
    "    features[bool_columns] = features[bool_columns].astype(int)\n",
    "    # transform only: reuse the training statistics rather than refitting on test data\n",
    "    features_scaled = scaler.transform(features)\n",
    "\n",
    "    test_pred = lr_model.predict(features_scaled)\n",
    "    mae = mean_absolute_error(test_labels, test_pred)\n",
    "    return {\"linear_model\": mae}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:04.415885Z",
     "start_time": "2024-07-20T17:37:04.097149Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "17b10355-4e25-4e75-84c5-fd95c0bd3dfb",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"865pt\" height=\"304pt\"\n",
       " viewBox=\"0.00 0.00 864.60 303.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 299.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-299.8 860.6,-299.8 860.6,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"31.38,-157.8 31.38,-287.8 116.22,-287.8 116.22,-157.8 31.38,-157.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-270.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- target_column -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>target_column</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M392.8,-256.6C392.8,-256.6 301.45,-256.6 301.45,-256.6 295.45,-256.6 289.45,-250.6 289.45,-244.6 289.45,-244.6 289.45,-205 289.45,-205 289.45,-199 295.45,-193 301.45,-193 301.45,-193 392.8,-193 392.8,-193 398.8,-193 404.8,-199 404.8,-205 404.8,-205 404.8,-244.6 404.8,-244.6 404.8,-250.6 398.8,-256.6 392.8,-256.6\"/>\n",
       "<text text-anchor=\"start\" x=\"300.25\" y=\"-233.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n",
       "<text text-anchor=\"start\" x=\"339.62\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M844.6,-160.6C844.6,-160.6 736.75,-160.6 736.75,-160.6 730.75,-160.6 724.75,-154.6 724.75,-148.6 724.75,-148.6 724.75,-109 724.75,-109 724.75,-103 730.75,-97 736.75,-97 736.75,-97 844.6,-97 844.6,-97 850.6,-97 856.6,-103 856.6,-109 856.6,-109 856.6,-148.6 856.6,-148.6 856.6,-154.6 850.6,-160.6 844.6,-160.6\"/>\n",
       "<text text-anchor=\"start\" x=\"735.55\" y=\"-137.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"780.18\" y=\"-109.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.29,-233.76C476.21,-242.38 599.51,-249.12 695.75,-210.8 718.39,-201.78 739.28,-184.84 755.58,-168.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"757.68,-171.55 762.15,-161.94 752.67,-166.66 757.68,-171.55\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M527.4,-183.6C527.4,-183.6 445.8,-183.6 445.8,-183.6 439.8,-183.6 433.8,-177.6 433.8,-171.6 433.8,-171.6 433.8,-132 433.8,-132 433.8,-126 439.8,-120 445.8,-120 445.8,-120 527.4,-120 527.4,-120 533.4,-120 539.4,-126 539.4,-132 539.4,-132 539.4,-171.6 539.4,-171.6 539.4,-177.6 533.4,-183.6 527.4,-183.6\"/>\n",
       "<text text-anchor=\"start\" x=\"444.6\" y=\"-160.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"476.1\" y=\"-132.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;linear_model -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.23,-194.48C411.28,-191.26 417.45,-187.99 423.56,-184.75\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"425.2,-187.84 432.39,-180.06 421.91,-181.65 425.2,-187.84\"/>\n",
       "</g>\n",
       "<!-- lr_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>lr_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M683.75,-201.6C683.75,-201.6 580.4,-201.6 580.4,-201.6 574.4,-201.6 568.4,-195.6 568.4,-189.6 568.4,-189.6 568.4,-150 568.4,-150 568.4,-144 574.4,-138 580.4,-138 580.4,-138 683.75,-138 683.75,-138 689.75,-138 695.75,-144 695.75,-150 695.75,-150 695.75,-189.6 695.75,-189.6 695.75,-195.6 689.75,-201.6 683.75,-201.6\"/>\n",
       "<text text-anchor=\"start\" x=\"603.58\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n",
       "<text text-anchor=\"start\" x=\"579.2\" y=\"-150.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n",
       "</g>\n",
       "<!-- lr_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>lr_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M695.87,-153.36C701.57,-151.87 707.39,-150.34 713.19,-148.82\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"714.06,-152.22 722.84,-146.3 712.28,-145.44 714.06,-152.22\"/>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M248.45,-105.6C248.45,-105.6 188.6,-105.6 188.6,-105.6 182.6,-105.6 176.6,-99.6 176.6,-93.6 176.6,-93.6 176.6,-54 176.6,-54 176.6,-48 182.6,-42 188.6,-42 188.6,-42 248.45,-42 248.45,-42 254.45,-42 260.45,-48 260.45,-54 260.45,-54 260.45,-93.6 260.45,-93.6 260.45,-99.6 254.45,-105.6 248.45,-105.6\"/>\n",
       "<text text-anchor=\"start\" x=\"187.4\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"208.02\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.18,-102.74C133.93,-98.4 150.48,-93.53 165.74,-89.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"166.52,-92.46 175.13,-86.28 164.54,-85.74 166.52,-92.46\"/>\n",
       "</g>\n",
       "<!-- test_dataset -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>test_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M385.3,-64.6C385.3,-64.6 308.95,-64.6 308.95,-64.6 302.95,-64.6 296.95,-58.6 296.95,-52.6 296.95,-52.6 296.95,-13 296.95,-13 296.95,-7 302.95,-1 308.95,-1 308.95,-1 385.3,-1 385.3,-1 391.3,-1 397.3,-7 397.3,-13 397.3,-13 397.3,-52.6 397.3,-52.6 397.3,-58.6 391.3,-64.6 385.3,-64.6\"/>\n",
       "<text text-anchor=\"start\" x=\"307.75\" y=\"-41.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"313\" y=\"-13.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;test_dataset -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;test_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-60.49C268.75,-57.86 277.43,-55.05 286.03,-52.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"286.96,-55.64 295.4,-49.23 284.81,-48.98 286.96,-55.64\"/>\n",
       "</g>\n",
       "<!-- train_dataset -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>train_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M387.93,-165.6C387.93,-165.6 306.32,-165.6 306.32,-165.6 300.32,-165.6 294.32,-159.6 294.32,-153.6 294.32,-153.6 294.32,-114 294.32,-114 294.32,-108 300.32,-102 306.32,-102 306.32,-102 387.93,-102 387.93,-102 393.93,-102 399.93,-108 399.93,-114 399.93,-114 399.93,-153.6 399.93,-153.6 399.93,-159.6 393.93,-165.6 387.93,-165.6\"/>\n",
       "<text text-anchor=\"start\" x=\"305.12\" y=\"-142.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">train_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"313\" y=\"-114.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;train_dataset -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;train_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M260.63,-93.28C268.06,-96.8 275.96,-100.54 283.83,-104.28\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"282.23,-107.39 292.76,-108.51 285.23,-101.07 282.23,-107.39\"/>\n",
       "</g>\n",
       "<!-- scaler -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>scaler</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M677.38,-119.6C677.38,-119.6 586.78,-119.6 586.78,-119.6 580.78,-119.6 574.78,-113.6 574.78,-107.6 574.78,-107.6 574.78,-68 574.78,-68 574.78,-62 580.78,-56 586.78,-56 586.78,-56 677.38,-56 677.38,-56 683.38,-56 689.38,-62 689.38,-68 689.38,-68 689.38,-107.6 689.38,-107.6 689.38,-113.6 683.38,-119.6 677.38,-119.6\"/>\n",
       "<text text-anchor=\"start\" x=\"612.58\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n",
       "<text text-anchor=\"start\" x=\"585.58\" y=\"-68.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n",
       "</g>\n",
       "<!-- scaler&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>scaler&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M689.61,-102.6C697.38,-104.64 705.46,-106.75 713.51,-108.86\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"712.41,-112.19 722.97,-111.34 714.19,-105.42 712.41,-112.19\"/>\n",
       "</g>\n",
       "<!-- test_dataset&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>test_dataset&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M397.75,-24.79C467.18,-15.77 596.01,-7.08 695.75,-46.8 718.39,-55.82 739.28,-72.76 755.58,-88.92\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"752.67,-90.94 762.15,-95.66 757.68,-86.05 752.67,-90.94\"/>\n",
       "</g>\n",
       "<!-- train_dataset&#45;&gt;linear_model -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>train_dataset&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M400.1,-140.61C407.37,-141.56 414.91,-142.54 422.36,-143.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"421.68,-146.96 432.05,-144.79 422.59,-140.02 421.68,-146.96\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;lr_model -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;lr_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-158.35C545.37,-159.05 551.11,-159.77 556.86,-160.49\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"556.06,-163.92 566.42,-161.69 556.93,-156.97 556.06,-163.92\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;scaler -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;scaler</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M539.8,-128.51C547.69,-125 555.92,-121.32 564.07,-117.69\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"565.22,-121.01 572.93,-113.74 562.37,-114.62 565.22,-121.01\"/>\n",
       "</g>\n",
       "<!-- _data_sets_inputs -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>_data_sets_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n",
       "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n",
       "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n",
       "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n",
       "</g>\n",
       "<!-- _data_sets_inputs&#45;&gt;data_sets -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>_data_sets_inputs&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M147.83,-53.78C153.77,-55.49 159.69,-57.19 165.43,-58.83\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"164.17,-62.11 174.75,-61.51 166.1,-55.39 164.17,-62.11\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-202.1 46.8,-202.1 46.8,-165.5 100.8,-165.5 100.8,-202.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-178\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-257.1C96.22,-257.1 51.37,-257.1 51.37,-257.1 45.37,-257.1 39.37,-251.1 39.37,-245.1 39.37,-245.1 39.37,-232.5 39.37,-232.5 39.37,-226.5 45.37,-220.5 51.37,-220.5 51.37,-220.5 96.22,-220.5 96.22,-220.5 102.22,-220.5 108.22,-226.5 108.22,-232.5 108.22,-232.5 108.22,-245.1 108.22,-245.1 108.22,-251.1 102.22,-257.1 96.22,-257.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-233\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<hamilton.driver.Driver at 0x1516d7100>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a Hamilton driver from the `pipeline` module.\n",
    "# Leaving `dr` as the last expression renders its dataflow graph below.\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_modules(pipeline)\n",
    "    .build()\n",
    ")\n",
    "dr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:06.029923Z",
     "start_time": "2024-07-20T17:37:05.934165Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "923eff9a-ce20-484e-a4c7-acfbebb58e16",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'evaluated_model': {'linear_model': 2.4926580150007007},\n",
       " 'linear_model': {'lr_model': LinearRegression(), 'scaler': StandardScaler()}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Request both nodes in a single run; `result` maps each requested node name to its value.\n",
    "requested_nodes = [\"evaluated_model\", \"linear_model\"]\n",
    "result = dr.execute(requested_nodes)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:36.535977Z",
     "start_time": "2024-07-20T17:37:36.241822Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "399b0164-819e-4c8a-a46d-021742ace28e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"607pt\" height=\"566pt\"\n",
       " viewBox=\"0.00 0.00 607.40 565.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 561.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-561.8 603.4,-561.8 603.4,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"30.62,-309.8 30.62,-549.8 116.97,-549.8 116.97,-309.8 30.62,-309.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-532.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- target_column -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>target_column</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M119.47,-299.6C119.47,-299.6 28.12,-299.6 28.12,-299.6 22.12,-299.6 16.12,-293.6 16.12,-287.6 16.12,-287.6 16.12,-248 16.12,-248 16.12,-242 22.12,-236 28.12,-236 28.12,-236 119.47,-236 119.47,-236 125.47,-236 131.47,-242 131.47,-248 131.47,-248 131.47,-287.6 131.47,-287.6 131.47,-293.6 125.47,-299.6 119.47,-299.6\"/>\n",
       "<text text-anchor=\"start\" x=\"26.92\" y=\"-276.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">target_column</text>\n",
       "<text text-anchor=\"start\" x=\"66.3\" y=\"-248.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M587.4,-228.6C587.4,-228.6 479.55,-228.6 479.55,-228.6 473.55,-228.6 467.55,-222.6 467.55,-216.6 467.55,-216.6 467.55,-177 467.55,-177 467.55,-171 473.55,-165 479.55,-165 479.55,-165 587.4,-165 587.4,-165 593.4,-165 599.4,-171 599.4,-177 599.4,-177 599.4,-216.6 599.4,-216.6 599.4,-222.6 593.4,-228.6 587.4,-228.6\"/>\n",
       "<text text-anchor=\"start\" x=\"478.35\" y=\"-205.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"522.98\" y=\"-177.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M131.81,-282.49C205.32,-298.56 335.92,-317.3 438.55,-278.8 461.54,-270.18 482.6,-253.09 498.92,-236.74\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"501.09,-239.54 505.49,-229.91 496.04,-234.69 501.09,-239.54\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"282.2,-269.6 176.6,-269.6 176.6,-206 282.2,-206 282.2,-269.6\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"188.6,-269.6 176.6,-257.6\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"176.6,-218 188.6,-206\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"270.2,-206 282.2,-218\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"282.2,-257.6 270.2,-269.6\"/>\n",
       "<text text-anchor=\"start\" x=\"187.4\" y=\"-246.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"218.9\" y=\"-218.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- target_column&#45;&gt;linear_model -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>target_column&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M131.56,-256.71C142.52,-254.57 154.06,-252.32 165.21,-250.14\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"165.57,-253.64 174.71,-248.29 164.23,-246.77 165.57,-253.64\"/>\n",
       "</g>\n",
       "<!-- lr_model -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>lr_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M426.55,-269.6C426.55,-269.6 323.2,-269.6 323.2,-269.6 317.2,-269.6 311.2,-263.6 311.2,-257.6 311.2,-257.6 311.2,-218 311.2,-218 311.2,-212 317.2,-206 323.2,-206 323.2,-206 426.55,-206 426.55,-206 432.55,-206 438.55,-212 438.55,-218 438.55,-218 438.55,-257.6 438.55,-257.6 438.55,-263.6 432.55,-269.6 426.55,-269.6\"/>\n",
       "<text text-anchor=\"start\" x=\"346.38\" y=\"-246.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lr_model</text>\n",
       "<text text-anchor=\"start\" x=\"322\" y=\"-218.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">LinearRegression</text>\n",
       "</g>\n",
       "<!-- lr_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>lr_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M438.67,-221.36C444.37,-219.87 450.19,-218.34 455.99,-216.82\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"456.86,-220.22 465.64,-214.3 455.08,-213.44 456.86,-220.22\"/>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M106.72,-147.6C106.72,-147.6 40.87,-147.6 40.87,-147.6 34.87,-147.6 28.87,-141.6 28.87,-135.6 28.87,-135.6 28.87,-96 28.87,-96 28.87,-90 34.87,-84 40.87,-84 40.87,-84 106.72,-84 106.72,-84 112.72,-84 118.72,-90 118.72,-96 118.72,-96 118.72,-135.6 118.72,-135.6 118.72,-141.6 112.72,-147.6 106.72,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"49.05\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"39.67\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M259.32,-105.6C259.32,-105.6 199.47,-105.6 199.47,-105.6 193.47,-105.6 187.47,-99.6 187.47,-93.6 187.47,-93.6 187.47,-54 187.47,-54 187.47,-48 193.47,-42 199.47,-42 199.47,-42 259.32,-42 259.32,-42 265.32,-42 271.32,-48 271.32,-54 271.32,-54 271.32,-93.6 271.32,-93.6 271.32,-99.6 265.32,-105.6 259.32,-105.6\"/>\n",
       "<text text-anchor=\"start\" x=\"198.27\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"218.9\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M118.82,-103.76C136.67,-98.88 157.42,-93.21 176.04,-88.12\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"176.83,-91.53 185.56,-85.52 174.99,-84.78 176.83,-91.53\"/>\n",
       "</g>\n",
       "<!-- test_dataset -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>test_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M413.05,-105.6C413.05,-105.6 336.7,-105.6 336.7,-105.6 330.7,-105.6 324.7,-99.6 324.7,-93.6 324.7,-93.6 324.7,-54 324.7,-54 324.7,-48 330.7,-42 336.7,-42 336.7,-42 413.05,-42 413.05,-42 419.05,-42 425.05,-48 425.05,-54 425.05,-54 425.05,-93.6 425.05,-93.6 425.05,-99.6 419.05,-105.6 413.05,-105.6\"/>\n",
       "<text text-anchor=\"start\" x=\"335.5\" y=\"-82.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"340.75\" y=\"-54.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;test_dataset -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;test_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M271.51,-73.8C284.52,-73.8 299.17,-73.8 313.24,-73.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"312.85,-77.3 322.85,-73.8 312.85,-70.3 312.85,-77.3\"/>\n",
       "</g>\n",
       "<!-- scaler -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>scaler</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M420.18,-187.6C420.18,-187.6 329.57,-187.6 329.57,-187.6 323.57,-187.6 317.57,-181.6 317.57,-175.6 317.57,-175.6 317.57,-136 317.57,-136 317.57,-130 323.57,-124 329.57,-124 329.57,-124 420.18,-124 420.18,-124 426.18,-124 432.18,-130 432.18,-136 432.18,-136 432.18,-175.6 432.18,-175.6 432.18,-181.6 426.18,-187.6 420.18,-187.6\"/>\n",
       "<text text-anchor=\"start\" x=\"355.38\" y=\"-164.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scaler</text>\n",
       "<text text-anchor=\"start\" x=\"328.38\" y=\"-136.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">StandardScaler</text>\n",
       "</g>\n",
       "<!-- scaler&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>scaler&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M432.41,-170.6C440.18,-172.64 448.26,-174.75 456.31,-176.86\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"455.21,-180.19 465.77,-179.34 456.99,-173.42 455.21,-180.19\"/>\n",
       "</g>\n",
       "<!-- test_dataset&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>test_dataset&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M425.25,-105.17C429.81,-108.34 434.3,-111.58 438.55,-114.8 455.87,-127.93 474.05,-143.37 489.76,-157.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"487.29,-159.78 497.08,-163.84 491.96,-154.57 487.29,-159.78\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;lr_model -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;lr_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M282.6,-237.8C288.03,-237.8 293.61,-237.8 299.22,-237.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"299.21,-241.3 309.21,-237.8 299.21,-234.3 299.21,-241.3\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;scaler -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;scaler</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M282.6,-207.96C290.78,-203.29 299.33,-198.41 307.76,-193.58\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"309.34,-196.72 316.28,-188.72 305.86,-190.64 309.34,-196.72\"/>\n",
       "</g>\n",
       "<!-- _data_sets_inputs -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>_data_sets_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"147.6,-65.6 0,-65.6 0,0 147.6,0 147.6,-65.6\"/>\n",
       "<text text-anchor=\"start\" x=\"43.67\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n",
       "<text text-anchor=\"start\" x=\"113.17\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"14.8\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">train_test_split</text>\n",
       "<text text-anchor=\"start\" x=\"107.55\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">float</text>\n",
       "</g>\n",
       "<!-- _data_sets_inputs&#45;&gt;data_sets -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>_data_sets_inputs&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M148,-52.34C157.45,-54.86 166.98,-57.41 176.01,-59.81\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"175.1,-63.19 185.66,-62.39 176.9,-56.43 175.1,-63.19\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"100.8,-354.1 46.8,-354.1 46.8,-317.5 100.8,-317.5 100.8,-354.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-330\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M96.22,-409.1C96.22,-409.1 51.37,-409.1 51.37,-409.1 45.37,-409.1 39.37,-403.1 39.37,-397.1 39.37,-397.1 39.37,-384.5 39.37,-384.5 39.37,-378.5 45.37,-372.5 51.37,-372.5 51.37,-372.5 96.22,-372.5 96.22,-372.5 102.22,-372.5 108.22,-378.5 108.22,-384.5 108.22,-384.5 108.22,-397.1 108.22,-397.1 108.22,-403.1 102.22,-409.1 96.22,-409.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-385\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- output -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>output</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M91.35,-464.1C91.35,-464.1 56.25,-464.1 56.25,-464.1 50.25,-464.1 44.25,-458.1 44.25,-452.1 44.25,-452.1 44.25,-439.5 44.25,-439.5 44.25,-433.5 50.25,-427.5 56.25,-427.5 56.25,-427.5 91.35,-427.5 91.35,-427.5 97.35,-427.5 103.35,-433.5 103.35,-439.5 103.35,-439.5 103.35,-452.1 103.35,-452.1 103.35,-458.1 97.35,-464.1 91.35,-464.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-440\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">output</text>\n",
       "</g>\n",
       "<!-- override -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>override</title>\n",
       "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"108.97,-519.1 38.62,-519.1 38.62,-482.5 108.97,-482.5 108.97,-519.1\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"50.62,-519.1 38.62,-507.1\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"38.62,-494.5 50.62,-482.5\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"96.97,-482.5 108.97,-494.5\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"108.97,-507.1 96.97,-519.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"73.8\" y=\"-495\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">override</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x151753040>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Visualize overrides: `linear_model` is supplied from the previous run,\n",
    "# so it is drawn as an override node and is not recomputed.\n",
    "dr.visualize_execution([\"evaluated_model\"],\n",
    "                       overrides={\"linear_model\": result[\"linear_model\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:37.088527Z",
     "start_time": "2024-07-20T17:37:37.025195Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "b617038c-2ffa-498c-b6c9-bbd6bb79bb1f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'evaluated_model': {'linear_model': 2.4926580150007007}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Execute with overrides: reuse `result[\"linear_model\"]` instead of recomputing that node.\n",
    "dr.execute(\n",
    "    [\"evaluated_model\"],\n",
    "    overrides={\"linear_model\": result[\"linear_model\"]},\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "mostRecentlyExecutedCommandWithImplicitDF": {
     "commandId": 2746022128672016,
     "dataframes": [
      "_sqldf"
     ]
    },
    "pythonIndentUnit": 4
   },
   "notebookName": "MPG Simple V1 Target",
   "widgets": {}
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
